{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vietnamese ULMFiT from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "from fastai import *\n",
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bs=48\n",
    "# bs=24\n",
    "bs=128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preferred GPU index. Only select it when this machine actually has a CUDA\n",
    "# device with that index; otherwise keep PyTorch's default device instead of raising.\n",
    "gpu_id = 2\n",
    "if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():\n",
    "    torch.cuda.set_device(gpu_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = Config.data_path()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This will create a `viwiki` folder, containing a `viwiki` text file with the Wikipedia contents. (For other languages, replace `vi` with the appropriate code from the [list of Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "lang = 'vi'\n",
    "# lang = 'zh'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working folder for the selected language, e.g. <data_path>/viwiki\n",
    "name = f'{lang}wiki'\n",
    "path = data_path/name\n",
    "path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# File stems (no extension) for the pretrained weights and their vocab\n",
    "wt_stem, vocab_stem = f'{lang}_wt', f'{lang}_wt_vocab'\n",
    "lm_fns = [wt_stem, vocab_stem]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vietnamese wikipedia model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Download data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "from nlputils import split_wiki,get_wiki"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "get_wiki(path,lang)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('/home/jhoward/data/zhwiki/docs'),\n",
       " PosixPath('/home/jhoward/data/zhwiki/zhwiki-latest-pages-articles.xml.bz2'),\n",
       " PosixPath('/home/jhoward/data/zhwiki/zh.cnf'),\n",
       " PosixPath('/home/jhoward/data/zhwiki/log'),\n",
       " PosixPath('/home/jhoward/data/zhwiki/zhwiki'),\n",
       " PosixPath('/home/jhoward/data/zhwiki/zhwiki-latest-pages-articles.xml'),\n",
       " PosixPath('/home/jhoward/data/zhwiki/wikiextractor')]"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<doc id=\"13\" url=\"https://vi.wikipedia.org/wiki?curid=13\" title=\"Tiếng Việt\">\r\n",
      "Tiếng Việt\r\n",
      "\r\n",
      "Tiếng Việt, còn gọi tiếng Việt Nam hay Việt ngữ, là ngôn ngữ của người Việt (người Kinh) và là ngôn ngữ chính thức tại Việt Nam. Đây là tiếng mẹ đẻ của khoảng 85% dân cư Việt Nam, cùng với hơn 4 triệu Việt kiều. Tiếng Việt còn là ngôn ngữ thứ hai của các dân tộc thiểu số tại Việt Nam. Mặc dù tiếng Việt có một số từ vựng vay mượn từ tiếng Hán và trước đây dùng chữ Nôm – một hệ chữ viết dựa trên chữ Hán – để viết nhưng tiếng Việt được coi là một trong số các ngôn ngữ thuộc ngữ hệ Nam Á có số người nói nhiều nhất (nhiều hơn một số lần so với các ngôn ngữ khác cùng hệ cộng lại). Ngày nay, tiếng Việt dùng bảng chữ cái Latinh, gọi là chữ Quốc ngữ, cùng các dấu thanh để viết.\r\n"
     ]
    }
   ],
   "source": [
    "!head -n4 {path}/{name}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "`split_wiki` splits the single Wikipedia file into a separate file per article, which is often easier to work with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "dest = split_wiki(path,lang)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('/home/jhoward/data/viwiki/docs/Luis Suárez.txt'),\n",
       " PosixPath('/home/jhoward/data/viwiki/docs/Vitas.txt'),\n",
       " PosixPath('/home/jhoward/data/viwiki/docs/Chùa Hà.txt'),\n",
       " PosixPath('/home/jhoward/data/viwiki/docs/Đại Phái bộ Sứ thần.txt'),\n",
       " PosixPath('/home/jhoward/data/viwiki/docs/2 Broke Girls.txt')]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dest.ls()[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Use this to convert Chinese traditional to simplified characters\n",
    "# ls *.txt | parallel -I% opencc -i % -o ../zhsdocs/% -c t2s.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Create pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Build the language-model DataBunch from the per-article text files,\n",
    "# holding out a random 10% of them (fixed seed) for validation.\n",
    "data = (TextList.from_folder(dest)\n",
    "            .split_by_rand_pct(0.1, seed=42)\n",
    "            .label_for_lm()\n",
    "            .databunch(bs=bs, num_workers=1))\n",
    "\n",
    "# The DataBunch was created from `dest`, so a bare file name would be written\n",
    "# under `dest`. Save under `path` instead, which is where `load_data(path, ...)` reads.\n",
    "data.save(path/f'{lang}_databunch')\n",
    "len(data.vocab.itos),len(data.train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "data = load_data(path, f'{lang}_databunch', bs=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "hidden": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "learn = language_model_learner(data, AWD_LSTM, drop_mult=0.5, pretrained=False).to_fp16()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Base learning rate 1e-2, scaled linearly by batch size relative to bs=48\n",
    "lr = 1e-2 * (bs/48)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.436113</td>\n",
       "      <td>3.491434</td>\n",
       "      <td>0.366925</td>\n",
       "      <td>28:52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.441240</td>\n",
       "      <td>3.544118</td>\n",
       "      <td>0.361326</td>\n",
       "      <td>28:33</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.571766</td>\n",
       "      <td>3.556932</td>\n",
       "      <td>0.358438</td>\n",
       "      <td>28:31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.510540</td>\n",
       "      <td>3.519243</td>\n",
       "      <td>0.362278</td>\n",
       "      <td>28:27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>3.447639</td>\n",
       "      <td>3.449320</td>\n",
       "      <td>0.369404</td>\n",
       "      <td>28:29</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.412284</td>\n",
       "      <td>3.406376</td>\n",
       "      <td>0.375022</td>\n",
       "      <td>28:20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.286754</td>\n",
       "      <td>3.255309</td>\n",
       "      <td>0.391874</td>\n",
       "      <td>28:19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>3.172497</td>\n",
       "      <td>3.128522</td>\n",
       "      <td>0.406803</td>\n",
       "      <td>28:37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.126867</td>\n",
       "      <td>3.025249</td>\n",
       "      <td>0.419882</td>\n",
       "      <td>28:36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>3.128793</td>\n",
       "      <td>2.991077</td>\n",
       "      <td>0.424622</td>\n",
       "      <td>28:39</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(10, lr, moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Save the pretrained model and vocab:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Store the pretrained weights (converted back to fp32, without optimizer state)\n",
    "# and the vocab under <path>/models, using the file stems from `lm_fns`.\n",
    "wt_fn, vocab_fn = lm_fns\n",
    "mdl_path = path/'models'\n",
    "mdl_path.mkdir(exist_ok=True)\n",
    "learn.to_fp32().save(mdl_path/wt_fn, with_opt=False)\n",
    "learn.data.vocab.save(mdl_path/f'{vocab_fn}.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vietnamese sentiment analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Language model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- [Data](https://github.com/ngxbac/aivivn_phanloaisacthaibinhluan/tree/master/data)\n",
    "- [Competition details](https://www.aivivn.com/contests/1)\n",
    "- Top 3 f1 scores: 0.900, 0.897, 0.897"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train_000000</td>\n",
       "      <td>Dung dc sp tot cam on \\nshop Đóng gói sản phẩm...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train_000001</td>\n",
       "      <td>Chất lượng sản phẩm tuyệt vời . Son mịn nhưng...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train_000002</td>\n",
       "      <td>Chất lượng sản phẩm tuyệt vời nhưng k có hộp ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train_000003</td>\n",
       "      <td>:(( Mình hơi thất vọng 1 chút vì mình đã kỳ vọ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train_000004</td>\n",
       "      <td>Lần trước mình mua áo gió màu hồng rất ok mà đ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             id                                            comment  label\n",
       "0  train_000000  Dung dc sp tot cam on \\nshop Đóng gói sản phẩm...      0\n",
       "1  train_000001   Chất lượng sản phẩm tuyệt vời . Son mịn nhưng...      0\n",
       "2  train_000002   Chất lượng sản phẩm tuyệt vời nhưng k có hộp ...      0\n",
       "3  train_000003  :(( Mình hơi thất vọng 1 chút vì mình đã kỳ vọ...      1\n",
       "4  train_000004  Lần trước mình mua áo gió màu hồng rất ok mà đ...      1"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Labelled training comments; missing comments are replaced by the literal string 'NA'\n",
    "train_df = pd.read_csv(path/'train.csv')\n",
    "train_df['comment'] = train_df['comment'].fillna('NA')\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>test_000000</td>\n",
       "      <td>Chưa dùng thử nên chưa biết</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_000001</td>\n",
       "      <td>Không đáng tiềnVì ngay đợt sale nên mới mua n...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test_000002</td>\n",
       "      <td>Cám ơn shop. Đóng gói sản phẩm rất đẹp và chắc...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>test_000003</td>\n",
       "      <td>Vải đẹp.phom oki luôn.quá ưng</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>test_000004</td>\n",
       "      <td>Chuẩn hàng đóng gói đẹp</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            id                                            comment\n",
       "0  test_000000                        Chưa dùng thử nên chưa biết\n",
       "1  test_000001   Không đáng tiềnVì ngay đợt sale nên mới mua n...\n",
       "2  test_000002  Cám ơn shop. Đóng gói sản phẩm rất đẹp và chắc...\n",
       "3  test_000003                      Vải đẹp.phom oki luôn.quá ưng\n",
       "4  test_000004                            Chuẩn hàng đóng gói đẹp"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unlabelled test comments; missing comments are replaced by the literal string 'NA'\n",
    "test_df = pd.read_csv(path/'test.csv')\n",
    "test_df['comment'] = test_df['comment'].fillna('NA')\n",
    "test_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.concat([train_df,test_df], sort=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lm = (TextList.from_df(df, path, cols='comment')\n",
    "    .split_by_rand_pct(0.1, seed=42)\n",
    "    .label_for_lm()           \n",
    "    .databunch(bs=bs, num_workers=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_lm = language_model_learner(data_lm, AWD_LSTM, pretrained_fnames=lm_fns, drop_mult=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base learning rate 1e-3 for fine-tuning, scaled linearly by batch size relative to bs=48\n",
    "lr = 1e-3 * (bs/48)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>4.975080</td>\n",
       "      <td>4.138585</td>\n",
       "      <td>0.317773</td>\n",
       "      <td>00:07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.408635</td>\n",
       "      <td>4.025489</td>\n",
       "      <td>0.326423</td>\n",
       "      <td>00:07</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn_lm.fit_one_cycle(2, lr*10, moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>4.142114</td>\n",
       "      <td>3.928278</td>\n",
       "      <td>0.336230</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.010835</td>\n",
       "      <td>3.793583</td>\n",
       "      <td>0.349972</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.873617</td>\n",
       "      <td>3.694702</td>\n",
       "      <td>0.357240</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.761377</td>\n",
       "      <td>3.632186</td>\n",
       "      <td>0.364648</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>3.679017</td>\n",
       "      <td>3.595601</td>\n",
       "      <td>0.366964</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.614548</td>\n",
       "      <td>3.576386</td>\n",
       "      <td>0.369224</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.575895</td>\n",
       "      <td>3.567496</td>\n",
       "      <td>0.370285</td>\n",
       "      <td>00:09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>3.560278</td>\n",
       "      <td>3.566525</td>\n",
       "      <td>0.370173</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn_lm.unfreeze()\n",
    "learn_lm.fit_one_cycle(8, lr, moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_lm.save(f'{lang}fine_tuned')\n",
    "learn_lm.save_encoder(f'{lang}fine_tuned_enc')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_clas = (TextList.from_df(train_df, path, vocab=data_lm.vocab, cols='comment')\n",
    "    .split_by_rand_pct(0.1, seed=42)\n",
    "    .label_from_df(cols='label')\n",
    "    .databunch(bs=bs, num_workers=1))\n",
    "\n",
    "data_clas.save(f'{lang}_textlist_class')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_clas = load_data(path, f'{lang}_textlist_class', bs=bs, num_workers=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "@np_func\n",
    "def f1(inp,targ):\n",
    "    \"F1 score of the arg-max predictions in `inp` against the labels in `targ` (default `f1_score` settings).\"\n",
    "    # Per-epoch values in the training tables can differ slightly from the score on the\n",
    "    # full validation set (see the `get_preds` cells) -- presumably because the metric is\n",
    "    # computed per batch and averaged; TODO confirm.\n",
    "    return f1_score(targ, np.argmax(inp, axis=-1))\n",
    "\n",
    "# The training tables list this metric under the wrapper's name (`_inner`);\n",
    "# name it explicitly so the column reads `f1`.\n",
    "f1.__name__ = 'f1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()\n",
    "learn_c.load_encoder(f'{lang}fine_tuned_enc')\n",
    "learn_c.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base learning rate 2e-2 for the classifier, scaled linearly by batch size relative to bs=48\n",
    "lr = 2e-2 * (bs/48)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>_inner</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.338150</td>\n",
       "      <td>0.275298</td>\n",
       "      <td>0.899876</td>\n",
       "      <td>0.878430</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.302302</td>\n",
       "      <td>0.245949</td>\n",
       "      <td>0.902985</td>\n",
       "      <td>0.877226</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>_inner</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.321768</td>\n",
       "      <td>0.255457</td>\n",
       "      <td>0.899254</td>\n",
       "      <td>0.871367</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.305934</td>\n",
       "      <td>0.250888</td>\n",
       "      <td>0.894901</td>\n",
       "      <td>0.872021</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>_inner</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.300939</td>\n",
       "      <td>0.261080</td>\n",
       "      <td>0.893657</td>\n",
       "      <td>0.866201</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.263790</td>\n",
       "      <td>0.220207</td>\n",
       "      <td>0.906716</td>\n",
       "      <td>0.886115</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Gradual unfreezing, step 1: freeze_to(-2), then train 2 epochs with a\n",
    "# learning-rate range from lr/(2.6**4) up to lr (passed as a slice).\n",
    "# NOTE(review): freeze_to(-2) presumably leaves the last two layer groups trainable -- confirm against the fastai docs.\n",
    "learn_c.freeze_to(-2)\n",
    "learn_c.fit_one_cycle(2, slice(lr/(2.6**4),lr), moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>_inner</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.282888</td>\n",
       "      <td>0.238203</td>\n",
       "      <td>0.905473</td>\n",
       "      <td>0.886483</td>\n",
       "      <td>00:04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.248599</td>\n",
       "      <td>0.216489</td>\n",
       "      <td>0.918532</td>\n",
       "      <td>0.901550</td>\n",
       "      <td>00:04</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Gradual unfreezing, step 2: freeze_to(-3), with the learning-rate range halved (upper bound lr/2).\n",
    "learn_c.freeze_to(-3)\n",
    "learn_c.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>_inner</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.201508</td>\n",
       "      <td>0.217176</td>\n",
       "      <td>0.911070</td>\n",
       "      <td>0.890084</td>\n",
       "      <td>00:05</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Final step: unfreeze the whole model and train one epoch with the learning-rate range cut to an upper bound of lr/10.\n",
    "learn_c.unfreeze()\n",
    "learn_c.fit_one_cycle(1, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_c.save(f'{lang}clas')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Competition top 3 F1 scores: 0.900, 0.897, 0.897. The winner used an ensemble of 4 models: TextCNN, VDCNN, HARNN, and SARNN."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_clas = load_data(path, f'{lang}_textlist_class', bs=bs, num_workers=1)\n",
    "learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()\n",
    "learn_c.load(f'{lang}clas', purge=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.9111), tensor(0.8952))"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds,targs = learn_c.get_preds(ordered=True)\n",
    "accuracy(preds,targs),f1(preds,targs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: neither the `{lang}_textlist_class_bwd` DataBunch nor the `{lang}clas_bwd` weights\n",
    "# are created anywhere in this notebook; both must already exist on disk -- presumably\n",
    "# from a separate training run on backwards (`backwards=True`) data. TODO confirm.\n",
    "data_clas_bwd = load_data(path, f'{lang}_textlist_class_bwd', bs=bs, num_workers=1, backwards=True)\n",
    "learn_c_bwd = text_classifier_learner(data_clas_bwd, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()\n",
    "learn_c_bwd.load(f'{lang}clas_bwd', purge=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.9092), tensor(0.8957))"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds_b,targs_b = learn_c_bwd.get_preds(ordered=True)\n",
    "accuracy(preds_b,targs_b),f1(preds_b,targs_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_avg = (preds+preds_b)/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.9154), tensor(0.9016))"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy(preds_avg,targs_b),f1(preds_avg,targs_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
